import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator
import numpy as np
import math
import sys,os

class Dijkstar:
    
    def __init__(self,width,height,gridSize,robot_location,target_location):
    # 初始化
        self.Width = width  # 宽度的栅格数
        self.Height = height
        self.gridSize = gridSize    # 栅格的大小
        self.grid = None    # 栅格空间 类型 np.zeros(width,height)
        self.obs = None     # 障碍物坐标
        self.robotLoc = robot_location  # 机器人起始在栅格的位置
        self.targetLoc = target_location    # 目标点位
        self.ENABLE_DRAW = True     # 是否允许绘制动图
        self.creatGrid()    # 创建栅格空间
        self.creatObstacle()    # 创建障碍物坐标，并绘制
        self.writeObstacle()    # 
        self.planning() # 规划轨迹
        
        if self.ENABLE_DRAW:
            plt.show()
        
    
    def creatGrid(self):
        # 创建栅格
        self.grid = np.zeros(((int)(self.Height), (int)(self.Width)))   # 初始化栅格空间
        plt.axis('scaled')  # 横纵坐标大小刻度一致     
        axis = plt.gca()    # 获取坐标轴的点
        xAxis = np.arange(0,(self.Width+1)*self.gridSize,self.gridSize) # 设置栅格
        yAxis = np.arange(0,(self.Height+1)*self.gridSize,self.gridSize)
        axis.set_xticks(xAxis)
        axis.set_yticks(yAxis)
        axis.tick_params(axis = 'both', direction = 'in')
        plt.grid(True)  # 绘制栅格

        # 绘制机器人坐标 和 终点坐标
        self.drawPoint(self.robotLoc,markersize = 4, marker = 'o', markerfacecolor = 'red', markeredgecolor = 'red')
        self.drawPoint(self.targetLoc,markersize = 4, marker = 'v', markerfacecolor = 'green', markeredgecolor = 'green')
        


    def creatObstacle(self):
        # 创建障碍物 
        self.obs = []

        
        obs1 = [[0,a] for a in range((int)(self.Height))]
        obs2 = [[(int)(self.Width) - 1,a] for a in range((int)(self.Height))]
        obs3 = [[a,0] for a in range((int)(self.Width))]
        obs4 = [[a,(int)(self.Height) - 1] for a in range((int)(self.Width))]
        obs5 = [[10,a] for a in range(35)] 
        obs6 = [[20,39 - a] for a in range(35)]
        obs7 = [[30,a] for a in range(35)]
        # 创建障碍物
        self.obs = self.obs + obs1 + obs2 + obs3 + obs4 + obs5  + obs7 # + obs6 + obs8 + obs9


        for item in list(self.obs):
            self.drawPoint(item, marker = 's', markerfacecolor = 'black', markeredgecolor = 'black')

    def writeObstacle(self):
        # 将障碍物 写进 栅格（相同的坐标置 1）
        print(self.grid)
        for item in list(self.obs):
            self.grid[item[0],item[1]] = 1;
        print(self.grid)

    def drawPoint(self,loction,markersize = 4,marker = 's',markerfacecolor = 'black',markeredgecolor = 'black'):
        # 将点绘制进栅格（因为栅格用刻度表示，所以绘制点时需要加上偏置）
        plt.plot(loction[0]*self.gridSize + self.gridSize/2,loction[1]*self.gridSize + self.gridSize/2,
                 linestyle = '',
                 markersize = 4,
                 marker = marker,
                 markerfacecolor = markerfacecolor,
                 markeredgecolor = markeredgecolor)
        

    class Node:
        # 节点类   
        def __init__(self,pos,cost,parent_pos):
            self.pos = pos
            self.cost = cost
            self.parent_pos = parent_pos

        def __str__(self):
            return str(pos) + "," + str(
                self.cost) + "," + str(self.parent_index)



    def planning(self):
    # 路径规划函数
        start_node = self.Node(self.robotLoc,0,[-1,-1])
        end_node = self.Node(self.targetLoc,0,[-1,-1])
        open_set, closed_set = dict(),dict()
        open_set[self.robotLoc] = start_node

        while True:

            minCostPos = min(open_set, key=lambda o: open_set[o].cost)  # 取Open_set 中 cost最小的节点的坐标
            # key=lambda o: open_set[o].cost      lambda是一个隐函数，是固定写法,等价于 def key(o): return open[o].cost
            # print(minCostPos)
            currentPos = minCostPos    # 当前收录节点是最小的节点
            closed_set[currentPos] = open_set[minCostPos]   # 收录的节点增加
            
            del open_set[minCostPos]    # 删除刚刚从 open_set 中收录的节点
                
            if currentPos == self.targetLoc:    # 如果当前点位的坐标等于终点 结束循环
                goal_node = self.Node(currentPos,closed_set[currentPos].cost,closed_set[currentPos].parent_pos) # 最终的节点设置，也可以直接等于closed_set[currentPos]
                break
            self.findNeiborsAndRecord(currentPos,open_set,closed_set)   # 找出当前的节点的临近节点，并更新距离

            if self.ENABLE_DRAW:    # 如果允许绘制
                self.drawPoint(currentPos, marker = 'x', markerfacecolor = 'blue', markeredgecolor = 'blue')    # 绘制出当前的节点
                if len(closed_set.keys()) % 80 == 0:
                    plt.pause(0.001)

        if self.ENABLE_DRAW:
            rx, ry = self.final_path(goal_node, closed_set)
            plt.plot(rx,ry,color = 'red',linewidth = 2)
    
    # 找出当前节点的临近节点
    def findNeiborsAndRecord(self,pos,open_set,closed_set):# pos 是当前节点
        motion = self.get_motion_model()   # 机器人移动类型 前、后、左、右、左前、左后、右前、右后。
        for item in motion:    # 获得n个临近节点，与机器人的运动形式有关
            # print(item[0:2])
            p = (pos[0] + item[0], pos[1] + item[1])    # 合成坐标

            if p[0] < 0 or p[1] < 0 or self.grid[p[0]][p[1]] == 1 or p in closed_set:   # 如果在栅格外 或者 在障碍物里 或者 已经被close_set收录
                continue
            
            if p not in open_set:   # 不在open_set里，则放进open_set 里
                node = self.Node(p,item[2] + closed_set[pos].cost,pos)  # 
                open_set[p] = node
            else:    # 如果在open_set里面，需要更新距离
                if open_set[p].cost > item[2] + closed_set[pos].cost:
                      node = self.Node(p,item[2] + closed_set[pos].cost,pos)
                      open_set[p] = node           
        

    def final_path(self, goal_node, closed_set):
    # 计算最终的路径，根据 节点的上一节点 进行遍历，寻找到第一个节点
        a = goal_node
        rx, ry = [],[]
        while a.parent_pos != self.robotLoc:
            rx.append(a.pos[0] * self.gridSize + self.gridSize/2)
            ry.append(a.pos[1] * self.gridSize + self.gridSize/2)
            a = closed_set[a.parent_pos]
        rx.append(self.robotLoc[0] * self.gridSize + self.gridSize/2)
        ry.append(self.robotLoc[1] * self.gridSize + self.gridSize/2)
        return rx, ry   # 返回节点的值


    @staticmethod
    def get_motion_model():
        # dx, dy, cost
        motion = [[1, 0, 1],
                  [0, 1, 1],
                  [-1, 0, 1],
                  [0, -1, 1],
                  [-1, -1, math.sqrt(2)],
                  [-1, 1, math.sqrt(2)],
                  [1, -1, math.sqrt(2)],
                  [1, 1, math.sqrt(2)]]
        return motion


def wait_key():
    ''' Wait for a key press on the console and return it. '''
    result = None
    if os.name == 'nt':
        import msvcrt
        result = msvcrt.getch()
    else:
        import termios
        fd = sys.stdin.fileno()
 
        oldterm = termios.tcgetattr(fd)
        newattr = termios.tcgetattr(fd)
        newattr[3] = newattr[3] & ~termios.ICANON & ~termios.ECHO
        termios.tcsetattr(fd, termios.TCSANOW, newattr)
 
        try:
            result = sys.stdin.read(1)
        except IOError:
            pass
        finally:
            termios.tcsetattr(fd, termios.TCSAFLUSH, oldterm)
 
    return result
def main():

    start_pos = (5,5)
    end_pos = (37,5)
    grid_width = 40
    grid_height = 40
    min_grid = 5
    plt.axis('scaled')  # 横纵坐标大小刻度一致
    axis = plt.gca()    # 获取坐标轴的点
    xAxis = np.arange(0,41*5,5) # 设置栅格
    yAxis = np.arange(0,41*5,5)
    axis.set_xticks(xAxis)
    axis.set_yticks(yAxis)
    axis.tick_params(axis = 'both', direction = 'in')
    plt.grid(True)  # 绘制栅格
    plt.ioff()
    plt.gcf().show()

    wait_key()
    dijkstar = Dijkstar(grid_width,grid_height,min_grid,start_pos,end_pos)

    return
    
if __name__ == '__main__':
    main()